{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchview import draw_graph\n",
    "from  generator.parser import Parser\n",
    "import generator.bricks as bricks\n",
    "import sys\n",
    "import time\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Parser: инициализация ...\n",
      "ERROR: В модуле не определен входной блок\n",
      "Время генерации моделей: 0.053220510482788086\n"
     ]
    }
   ],
   "source": [
    "# Тест создания моделей из выражения строки\n",
    "device = \"cuda\"\n",
    "parser = Parser() # создаем парсер\n",
    "s1 = \"output={{@4->relu+@8->relu}^2}%2->@16->softmax->linear(5);\"\n",
    "s2 = \"output={{@16->relu+@16->sigmoid}^4}%8->@16;\"\n",
    "s3 = \"output={{@64->relu}^64}%64;\"\n",
    "s4 = \"output = linear(5) -> softmax;\"\n",
    "s5 = \"output = {@5->@20 + @10->@20 + @20} -> softmax;\"\n",
    "# s6 = \"\"\"\n",
    "# x = @64;            # x - linear 64 нейрона\n",
    "# y = x + @64;        # y - параллельно соединены x и модуль из 64 нейронов\n",
    "# z = x -> y;         # z - x последовательно соединен с y\n",
    "# w = @8 ^ 4;         # w - 4 слоя по 16 нейронов последовательно соединены\n",
    "# a = x % 2;          # a - параллельно соединены два модуля x\n",
    "# output = z -> w -> a -> {{@8 -> relu + @8 -> relu} ^ 2} % 2 -> @16 -> softmax;\n",
    "# \"\"\"\n",
    "s6 = \"\"\"\n",
    "    y = @64 + @64;          # y - параллельно соединены x и модуль из 64 нейронов\n",
    "    z = @8 -> y;            # z - x последовательно соединен с y\n",
    "    w = @8 ^ 4;             # w - 4 слоя по 8 нейронов последовательно соединены\n",
    "    a = {@16 + @16} % 2;    # a - параллельно соединены два модуля x\n",
    "    output = z -> w -> a -> {{@8 -> relu + @8 -> relu} ^ 2} % 2 -> @16 -> softmax;\n",
    "\"\"\"\n",
    "start_time = time.time()\n",
    "# создаем модель из строки s5\n",
    "modules = parser.from_str(s6)\n",
    "# # Создаем модули их json-файла\n",
    "# modules = parser.from_json('nntest.json')\n",
    "model = modules['output'].to(device)\n",
    "end_time = time.time()\n",
    "print(f\"Время генерации моделей: {end_time-start_time}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.0510, 0.0575, 0.0555, 0.0638, 0.0619, 0.0682, 0.0859, 0.0619, 0.0510,\n",
      "         0.0963, 0.0539, 0.0531, 0.0558, 0.0767, 0.0598, 0.0479]],\n",
      "       device='cuda:0', grad_fn=<SoftmaxBackward0>)\n",
      "{{{{{linear(8)->{linear(64)+linear(64)}}->linear(8)^4}->{linear(16)+linear(16)}%2}->{{linear(8)->relu}+{linear(8)->relu}}^2%2}->linear(16)}\n",
      "{{linear(8)->relu}+{linear(8)->relu}}^2%2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'y': {linear(64)+linear(64)},\n",
       " 'z': {linear(8)->{linear(64)+linear(64)}},\n",
       " 'w': linear(8)^4,\n",
       " 'a': {linear(16)+linear(16)}%2,\n",
       " 'output': {{{{{{linear(8)->{linear(64)+linear(64)}}->linear(8)^4}->{linear(16)+linear(16)}%2}->{{linear(8)->relu}+{linear(8)->relu}}^2%2}->linear(16)}->softmax}}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Тестируем работу модели на тестовом тензоре\n",
    "x = torch.randn(1, 10).to(device)\n",
    "output2 = model.to(device)(x)\n",
    "print(output2)\n",
    "# Получим один из элементов модели\n",
    "print(modules['output'].left)\n",
    "print(modules['output'].left.left.right)\n",
    "modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 10)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tuple(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Параметров: 550, время создания:  0.009173870086669922\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 20])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Подсчитаем размер модели\n",
    "params_count = sum(p.numel() for p in modules['output'].parameters())\n",
    "print(f\"Параметров: {params_count}, время создания: \", end_time - start_time)\n",
    "output2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{{{{linear(5)->linear(20)}+{linear(10)->linear(20)}}+linear(20)}->softmax}'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Представление модели в виде выражения\n",
    "model.expr_str(expand=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Нарисуем модель\n",
    "input_size = (1, 5)\n",
    "pic_path = './pic/'\n",
    "model_graph = draw_graph(\n",
    "    model, \n",
    "    input_size=input_size,\n",
    "    graph_name='test',\n",
    "    graph_dir='LR',\n",
    "    depth=8,\n",
    "    hide_inner_tensors=True,\n",
    "    hide_module_functions=True,\n",
    "    save_graph=True,\n",
    "    expand_nested=True,\n",
    "    show_shapes=True,\n",
    "    filename='test',\n",
    "    directory=pic_path\n",
    ")\n",
    "\n",
    "# model_graph.resize_graph(scale=3)\n",
    "# model_graph.visual_graph.view('./pic/test',)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.1565,  0.3500, -0.3436,  0.3257,  0.0636,  0.4658,  0.6467,  0.0350,\n",
      "         -0.0691, -0.6981,  1.4110, -0.1775,  0.6306,  0.1815,  0.0094, -0.4796,\n",
      "          0.3549, -0.6143,  0.0438, -0.2355]], device='cuda:0',\n",
      "       grad_fn=<AddmmBackward0>)\n",
      "tensor([[ 0.9092, -0.6238,  0.1989,  0.1603, -1.8952,  0.2508,  0.0374, -0.5332,\n",
      "          0.0413,  0.5759,  0.3830,  0.1071, -0.4529,  0.4732, -0.6295,  0.4190,\n",
      "         -0.6940,  0.9568,  0.5844, -0.3995, -0.5055, -0.3539,  0.0489,  0.9234,\n",
      "          0.3009]], device='cuda:0', grad_fn=<CatBackward0>)\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(input_size).to(device)\n",
    "# Извлечем элемент left.right из модели\n",
    "chunk = modules['output'].left.right\n",
    "print(chunk(x))\n",
    "model_graph = draw_graph(\n",
    "    chunk, \n",
    "    input_size=input_size,\n",
    "    graph_name='left.right',\n",
    "    graph_dir='LR',\n",
    "    depth=8,\n",
    "    hide_inner_tensors=True,\n",
    "    hide_module_functions=True,\n",
    "    save_graph=True,\n",
    "    expand_nested=True,\n",
    "    show_shapes=True,\n",
    "    filename='left.right',\n",
    "    directory='./pic/'\n",
    ")\n",
    "# Разделим на две части, соединенные параллельно\n",
    "left, right = chunk.decompose()\n",
    "new_chunk = bricks.Connector(left, right).to(\"cuda\")\n",
    "print(new_chunk(x))\n",
    "model_graph = draw_graph(\n",
    "    new_chunk, \n",
    "    input_size=input_size,\n",
    "    graph_name='decompose',\n",
    "    graph_dir='LR',\n",
    "    depth=8,\n",
    "    hide_inner_tensors=True,\n",
    "    hide_module_functions=True,\n",
    "    save_graph=True,\n",
    "    expand_nested=True,\n",
    "    show_shapes=True,\n",
    "    filename='decompose',\n",
    "    directory='./pic/'\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
